%This script implements preliminary steps that are shared by all estimation codes
%Corresponding author: Rodrigo Adao
%Date: 09/11/2024
%Input: model_data_fgkk.mat

%% Set parameters
%Defines the model elasticities, tax-revenue allocation switches,
%misspecification parameters, numerical tolerances and test settings used
%by the rest of the script.

%Set elasticities in the model
%Variety-level parameters
%These defaults are applied only when omega_star is not already defined as a
%workspace variable, so a calling script can preset its own values. The 'var'
%flag restricts the lookup to workspace variables; without it, a file or
%function called omega_star on the path would silently suppress the defaults.
%NOTE(review): sigma and sigma_star share the omega_star guard, so a caller that
%presets omega_star must preset sigma and sigma_star as well.
%NOTE(review): the original comment says baseline simulations use
%omega_star = 0.02, but the default assigned here is 0 -- confirm which is intended.
if ~exist('omega_star', 'var')
    omega_star = 0;  %foreign export supply elasticity (FGKK est = -.002); baseline simulations use omega_star = 0.02 to avoid zeros
    sigma = 2.53;       %elasticity of substitution across varieties (FGKK est = 2.53)
    sigma_star = 1.04;  %foreign import demand elasticity (FGKK est = 1.04)
end
% Other parameters
eta = 1.53;         %elasticity of substitution across products (FGKK est = 1.53)
kappa = 1.19;       %elasticity of substitution across domestic/foreign composites (FGKK est = 1.19)
nu =  1;            %elasticity of substitution across sector intermediate good (FGKK set nu = 1)
epsilon =  1;       %elasticity of substitution across sector final good (FGKK set epsilon = 1)
DX = 0;             %DX = 0 for export tax revenue to be accrued to foreign, and DX = 1 to be accrued to Home
DM = 1;             %DM = 0 for import tax revenue to be accrued to foreign, and DM = 1 to be accrued to Home
gammaM = 1;         %Import price flat misspecification
gammaX = 1;         %Export price flat misspecification
gammaQ = 1;         %NOTE(review): original comment repeats gammaM's ("Import price flat misspecification"); presumably the import quantity counterpart -- confirm
rhoX = 1;           %Misspecification of foreign tariffs
rhoM = 1;           %Misspecification of home tariffs

%Numerical parameters
tol_parm = 1e-4;    %tolerance level for checks in parameter inversion
tol_conv = 1e-7;    %tolerance level to check for price convergence
adj_inner = .4;     %parameter to adjust sectoral domestic price in inner loop
adj_outter = .8;    %parameter to adjust sectoral import price in outer loop (only matters when omega_star != 0)
adj_inner_g0 = .1;  %inner-loop adjustment passed to compute_equilibrium_foa_full
adj_outter_g0 = .8; %outer-loop adjustment passed to compute_equilibrium_foa_full

%Testing parameters
alpha = 0.05;                            %significance level
critical = norminv(1 - alpha/2,0,1);     %two-sided standard normal critical value for that level

%% Import and adjust data
%Loads the model data set, builds the product-to-sector incidence matrix,
%rescales trade values to the sector totals, defines the initial/final
%tariff matrices and computes the value thresholds used to select large
%varieties.

%NOTE(review): the file name is appended to data_path by plain concatenation,
%so data_path presumably has to end with a file separator -- confirm against the caller.
%load(data_path + "model_data_fgkk\model_data_fgkk.mat")
%load(data_path + "model_data_fgkk/model_data_fgkk.mat")
load(data_path + "model_data_fgkk.mat")

%Vector of dummies linking product to sector
%Dsg(s,g) = 1 when the last column of hs10_naics for product g equals naics(s).
%NOTE(review): S is read here before it is assigned from size(Lsr) further
%below, so it must already exist at this point (presumably loaded from the
%.mat file) -- confirm.
    [N, G] = size(x_val);
    Dsg = zeros(S,G);
    for s=1:S
       Dsg(s,:) = ( hs10_naics(:,end) == naics(s) )';
    end
    Dsg = sparse(Dsg);

%Adjust trade values to have the same value as in the IO table
%adj_m_s is the ratio of PMs_by_Ms to the sector total of m_val; adj_m_gs puts
%each product's sector ratio on a G-by-G diagonal, so the product below
%rescales every product column by its sector's ratio. Exports are treated the
%same way using EXs.
%NOTE(review): a sector whose summed trade value is zero yields a division by
%zero in adj_m_s / adj_x_s -- confirm this cannot occur in the data.
    m_val_sample = sparse(m_val);
    adj_m_s = PMs_by_Ms./( Dsg*sum(m_val_sample,1)' );
    adj_m_gs = spdiags( ( adj_m_s'*Dsg)' , 0, G, G);
    m_val_adj = m_val_sample*adj_m_gs;

    x_val_sample = sparse(x_val);
    adj_x_s = EXs./( Dsg*sum(x_val_sample,1)' );
    adj_x_gs = spdiags( ( adj_x_s'*Dsg)' , 0, G, G);
    x_val_adj = x_val_sample*adj_x_gs;

%Define variables
%x_ig_0 / m_ig_0 are the rescaled export / import values; the tau matrices are
%sparse copies of the initial (_0) and final (_1) tariff inputs.
    x_ig_0 = x_val_adj;
    m_ig_0 = m_val_adj;
    Lsr = sparse( Lsr );

    tau_star_ig_0 = sparse(x_tf_counter_initial);
    tau_ig_0 = sparse(m_tf_counter_initial);
    tau_star_ig_1 = sparse(x_tf_counter_final);    
    tau_ig_1      = sparse(m_tf_counter_final);    
    
    %R and S are the row and column counts of Lsr (S is overwritten here)
    [R, S] = size(Lsr);

%Compute minimum levels for targeting trade share
%Values are sorted in descending order and accumulated as a share of the total;
%min_imp / min_exp is the value at the position where the cumulative share is
%closest to trade_target.
    m_ig_0_sort =  sort( reshape( m_ig_0' , N*G, 1) , 'descend') ;
    m_ig_0_cumsum = cumsum(m_ig_0_sort)/sum(m_ig_0, 'all');
    [min_val, min_loc] = min(abs(m_ig_0_cumsum - trade_target)); 
    min_imp = m_ig_0_sort(min_loc);
    
    %min_val and min_loc are overwritten by the export calculation below
    x_ig_0_sort =  sort( reshape( x_ig_0' , N*G, 1) , 'descend') ;
    x_ig_0_cumsum = cumsum(x_ig_0_sort)/sum(x_ig_0, 'all');
    [min_val, min_loc] = min(abs(x_ig_0_cumsum - trade_target)); 
    min_exp = x_ig_0_sort(min_loc);

%% Invert parameters given initial conditions
%Recovers the model parameters from the initial trade values, tariffs and
%aggregate data by calling invert_param_full (defined outside this file).

%m_g_0: tariff-inclusive import value by product (sum over rows of m_ig_0);
%m_s_0: the same aggregated to sectors through Dsg.
m_g_0 = sum( (1+tau_ig_0).*m_ig_0, 1)'; 
m_s_0 = Dsg*m_g_0;

%The outputs with suffix _0 are passed on to compute_equilibrium_foa_full later
%in the script; F_0 is also used to scale the welfare weights.
[delta_ig_0, a_star_ig_0, z_star_ig_0, a_ig_0, aM_g_0, AM_s_0, AD_s_0, betaT_0, beta_s_0, alphaL_s_0, alphaI_s_0, alpha_ks_0, barL_rs_0, Z_rs_0, D_0, E_s_0, F_0] = ... 
    invert_param_full(x_ig_0, m_ig_0, tau_star_ig_0, tau_ig_0, trade_con_share, tot_labor_comp, tot_intermediate, sales_s, IO_sales, Lsr, ...
    Dsg, DX, DM, omega_star, sigma, eta, kappa, sigma_star, gammaM, gammaX, tol_parm);

%S-by-1 vector of ones, passed as p_s_0 to compute_equilibrium_foa_full
p_s_0 = ones(S,1);

%% Organize outcomes for estimation
%Selects the large varieties, builds the outcome variables from the data,
%drops observations with missing/infinite outcomes and stacks the selected
%outcomes into one vector ordered as [imports; tariff revenue; exports].

%Define vector of large varieties in our sample
%A variety is a shifter when its initial value is strictly above the threshold.
%The _vec versions stack the N-by-G masks into N*G-by-1 vectors with the
%column (product) index varying fastest.
    ind_Mshifts_mat = m_ig_0 > min_imp;
    ind_Xshifts_mat = x_ig_0 > min_exp;
    ind_Mshifts_vec = reshape( (  ind_Mshifts_mat  )', N*G,1);
    ind_Xshifts_vec = reshape( (  ind_Xshifts_mat  )', N*G,1);

    %First sample criterion: same size thresholds as the shifters
    %(the tariff-revenue sample uses the import threshold)
    ind_Msample_mat1 = m_ig_0 > min_imp;
    ind_Tsample_mat1 = m_ig_0 > min_imp;
    ind_Xsample_mat1 = x_ig_0 > min_exp;

    %Share of total import / export value covered by the size criterion alone
    tradestat_sample1 = [sum(m_ig_0(ind_Msample_mat1),'all')/sum(m_ig_0,'all'), ...
    sum(x_ig_0(ind_Xsample_mat1),'all')/sum(x_ig_0,'all')] ; 

%Organize data for estimation
%Log changes between the *_2018_3 and *_2019_3 series. Price changes are
%computed as the value log change minus the quantity log change; import
%prices use the m_valduty series.
%d_rm_data is the change in (m_valduty - m_val) divided by m_valduty_2018_3.
    dln_qx_data = log(x_q1_2019_3) -log(x_q1_2018_3);
    dln_vx_data = log(x_val_2019_3)-log(x_val_2018_3);
    dln_px_data   = dln_vx_data - dln_qx_data;
    dln_qm_data = log(m_q1_2019_3) -log(m_q1_2018_3);
    dln_pm_data = log(m_valduty_2019_3)-log(m_valduty_2018_3) - dln_qm_data;
    d_rm_data = ( (m_valduty_2019_3 - m_val_2019_3)-(m_valduty_2018_3 - m_val_2018_3) )./m_valduty_2018_3;

    dXout = dln_px_data;
    dMout = dln_pm_data; 
    dTout = d_rm_data;  

    %The wildcards also remove m_val_adj, m_val_sample, x_val_adj and
    %x_val_sample; m_ig_0 and x_ig_0 already hold the adjusted values.
    clearvars m_val* m_q1* m_tf* x_val* x_q1* x_tf* 

    %define sample for estimation: drop missing values
    %Second sample criterion: outcome is neither NaN nor Inf
    ind_Msample_mat2 = isnan(dMout) + isinf(dMout) == 0;
    ind_Tsample_mat2 = isnan(dTout) + isinf(dTout) == 0;
    ind_Xsample_mat2 = isnan(dXout) + isinf(dXout) == 0;

    %Both criteria must hold (sum of the two masks equals 2)
    ind_Msample_mat_aux = ind_Msample_mat1 + ind_Msample_mat2 == 2;
    ind_Tsample_mat_aux = ind_Tsample_mat1 + ind_Tsample_mat2 == 2;
    
    %Import and tariff-revenue samples are forced to coincide: a variety is
    %kept only when it passes both the import and the tariff-revenue checks
    ind_Msample_mat = ind_Msample_mat_aux.*ind_Tsample_mat_aux == 1;
    ind_Tsample_mat = ind_Msample_mat;

    ind_Xsample_mat = ind_Xsample_mat1 + ind_Xsample_mat2 == 2;

    %Stack the sample masks in the same order as the shifter masks
    ind_Msample_ig = reshape( (  ind_Msample_mat  )', N*G,1);
    ind_Tsample_ig = reshape( (  ind_Tsample_mat  )', N*G,1);
    ind_Xsample_ig = reshape( (  ind_Xsample_mat  )', N*G,1);

    %Number of import / export shifters; exp_sind flags the export shifters
    %in the stacked [import; export] shifter vector
    NMs = sum(ind_Mshifts_vec);
    NXs = sum(ind_Xshifts_vec);
    exp_sind = [zeros(NMs,1); ones(NXs,1)];

    %Number of observations per outcome and indicator vectors for the
    %stacked [import; tariff revenue; export] outcome vector
    NM = sum(ind_Msample_ig);
    NT = sum(ind_Tsample_ig);
    NX = sum(ind_Xsample_ig);
    indM = [ones(NM,1); zeros(NT,1); zeros(NX,1)];
    indT = [zeros(NM,1); ones(NT,1); zeros(NX,1)];
    indX = [zeros(NM,1); zeros(NT,1); ones(NX,1)];

    Nsample = [NM, NT, NX];

    %Share of total import / export value covered by the final samples
    tradestat_sample = [sum(m_ig_0(ind_Msample_mat),'all')/sum(m_ig_0,'all'), ... 
    sum(x_ig_0(ind_Xsample_mat),'all')/sum(x_ig_0,'all')] ; 

%Define outcomes for estimation

    %Vectorize outcomes
    m_ig_0_vec = reshape( (  m_ig_0  )', N*G,1);
    x_ig_0_vec = reshape( (  x_ig_0  )', N*G,1);
    taustar_ig_0_vec = reshape( (  tau_star_ig_0  )', N*G,1);
    tau_ig_0_vec = reshape( (  tau_ig_0  )', N*G,1);
    dMout = reshape( dMout', N*G,1);
    dTout = reshape( dTout', N*G,1);
    dXout = reshape( dXout', N*G,1);
    
    %select sample
    m_ig_0_sample = m_ig_0_vec(ind_Msample_ig );
    x_ig_0_sample = x_ig_0_vec(ind_Xsample_ig);
    tau_ig_0_sample = tau_ig_0_vec(ind_Msample_ig );
    taustar_ig_0_sample = taustar_ig_0_vec(ind_Xsample_ig );
    
    %Stacked outcome vector; the block order matches indM / indT / indX
    dMout = dMout(ind_Msample_ig);
    dTout = dTout(ind_Tsample_ig);
    dXout = dXout(ind_Xsample_ig);
    dWout = [dMout; dTout; dXout];
    dWout_sample = dWout;

%sample stats
%Prints the sample sizes and the trade coverage (in percent) of the final
%sample (first row) and of the size criterion alone (second row)
    display(full([Nsample]))
    display(full([100*tradestat_sample; 100*tradestat_sample1]))
   
%% Naive IV
%For each outcome block (imports, tariff revenue, exports) builds a selector
%matrix with one row per sampled variety that is also a shifter and a single 1
%in the column of that variety's own shifter. The three selectors are combined
%into shareIV_tau, whose columns are ordered [import shifters, export shifters].
%
%The selectors are assembled directly in sparse form, one nonzero per row,
%instead of filling dense NM-by-NMs matrices element by element; the resulting
%shareIV_tau is the same sparse matrix.
    variety_list = (1:1:N*G)';

    %Imports
    %sampleM_naive flags sampled varieties that are shifters;
    %sampleM_naive_shift_loc gives their position among the import shifters
    id_Mshifts_vec = variety_list(ind_Mshifts_vec);
    id_Msample_vec = variety_list(ind_Msample_ig);
    [sampleM_naive, sampleM_naive_shift_loc] = ismember(id_Msample_vec, id_Mshifts_vec);

    NMnaive = nnz(sampleM_naive);
    shareIV_M = sparse((1:NMnaive)', sampleM_naive_shift_loc(sampleM_naive), ...
                       ones(NMnaive,1), NMnaive, full(NMs));

    %Tariff revenue
    %Tariff-revenue outcomes are matched to the import shifters
    id_Tshifts_vec = variety_list(ind_Mshifts_vec);
    id_Tsample_vec = variety_list(ind_Tsample_ig);
    [sampleT_naive, sampleT_naive_shift_loc] = ismember(id_Tsample_vec, id_Tshifts_vec);

    NTnaive = nnz(sampleT_naive);
    shareIV_T = sparse((1:NTnaive)', sampleT_naive_shift_loc(sampleT_naive), ...
                       ones(NTnaive,1), NTnaive, full(NMs));

    %Exports
    id_Xshifts_vec = variety_list(ind_Xshifts_vec);
    id_Xsample_vec = variety_list(ind_Xsample_ig);
    [sampleX_naive, sampleX_naive_shift_loc] = ismember(id_Xsample_vec, id_Xshifts_vec);

    NXnaive = nnz(sampleX_naive);
    shareIV_X = sparse((1:NXnaive)', sampleX_naive_shift_loc(sampleX_naive), ...
                       ones(NXnaive,1), NXnaive, full(NXs));

    %Import and tariff-revenue rows load on import shifters only, export rows
    %on export shifters only; the off-diagonal blocks are all-zero sparse
    shareIV_tau = [shareIV_M, sparse(NMnaive, full(NXs)); ...
                   shareIV_T, sparse(NTnaive, full(NXs)); ...
                   sparse(NXnaive, full(NMs)), shareIV_X];

    %Rows of the stacked outcome vector that belong to the naive sample
    sample_naive = [sampleM_naive; sampleT_naive; sampleX_naive];

    %Outcome-type indicators restricted to the naive sample
    indMn = indM(sample_naive);
    indTn = indT(sample_naive);
    indXn = indX(sample_naive);

%% Define shifters for estimation
%Stacks the initial (_0) and final (_1) tariff matrices into N*G-by-1 vectors
%(transposed first, so the column index varies fastest), takes the
%final-minus-initial change for the selected import and export shifters, and
%standardizes the stacked [import; export] shifter vector.

    %Stacked tariff levels
    tau_ig_0_vec     = reshape(tau_ig_0.',      N*G, 1);
    taustar_ig_0_vec = reshape(tau_star_ig_0.', N*G, 1);
    tau_ig_1_vec     = reshape(tau_ig_1.',      N*G, 1);
    taustar_ig_1_vec = reshape(tau_star_ig_1.', N*G, 1);

    %Tariff changes, restricted to the shifter varieties
    dtau_sh = tau_ig_1_vec - tau_ig_0_vec;
    dtau_sh = dtau_sh(ind_Mshifts_vec);

    dtau_star_sh = taustar_ig_1_vec - taustar_ig_0_vec;
    dtau_star_sh = dtau_star_sh(ind_Xshifts_vec);

    %Standardized shifters: demeaned and divided by the standard deviation
    shifters   = vertcat(dtau_sh, dtau_star_sh);
    avg_tariff = mean(shifters);
    std_tariff = std(shifters);
    shiftersIV = (shifters - avg_tariff)./std_tariff;

%% Compute welfare weights
%Weights are expressed in percent of F_0. The tariff-revenue weight is the
%tariff-inclusive initial import value; the import weight is its negative;
%the export weight is the initial export value.

    %Weights for every variety
    ww_T = 100*(1+tau_ig_0_vec).*m_ig_0_vec/F_0;
    ww_M = -ww_T;
    ww_X = 100*(x_ig_0_vec/F_0);

    %Keep the estimation samples only
    ww_M = ww_M(ind_Msample_ig);
    ww_T = ww_T(ind_Tsample_ig);
    ww_X = ww_X(ind_Xsample_ig);

    %Stack in the same [import; tariff revenue; export] order as the outcomes
    ww_dW = vertcat(ww_M, ww_T, ww_X);

    %weight matrices
    %Diagonal weight matrices for the full and the naive samples
    ww_dWmat       = diag(ww_dW);
    ww_dWmat_naive = diag(ww_dW(sample_naive));

%% Clustering of standard errors
%test_cluster == 1: cluster_vec is empty.
%test_cluster == 2: every shifter is assigned the id of its country-by-sector
%pair; cluster_vec stacks the ids of the import shifters and of the export
%shifters in that order.
%NOTE(review): cluster_vec is left undefined for any other value of
%test_cluster -- confirm that callers only use 1 or 2.
if test_cluster == 1
    cluster_vec=[];
end
if test_cluster == 2
   cty_list = unique(all_cty);
   sec_list = unique(hs10_naics(:,end));

   %cluster_list: [cluster id, country, sector]; the id of the pair
   %(cty_list(i), sec_list(s)) is (s-1)*N + i
   cluster_list = [];
   for s=1:S
       aux = [cty_list, ones(N,1)*sec_list(s)];
       cluster_list = [cluster_list; aux];
   end
   cluster_list = [(1:1:S*N)', cluster_list];

   %Country and sector of every variety, stacked like the other N*G vectors
   cty_vec = reshape( repmat(all_cty, 1, G)', N*G, 1 );
   sec_vec = reshape( repmat(hs10_naics(:,end)', N, 1)', N*G, 1 );

   %Look up each variety's country and sector position once instead of
   %scanning all N*G varieties for each of the S*N clusters. Varieties whose
   %country or sector is not in the lists keep id 0.
   [~, cty_pos] = ismember(cty_vec, cty_list);
   [~, sec_pos] = ismember(sec_vec, sec_list(1:S));
   matched = (cty_pos > 0) & (sec_pos > 0);

   cluster_vec_aux = zeros(N*G,1);
   cluster_vec_aux(matched) = (sec_pos(matched) - 1)*N + cty_pos(matched);

    %NOTE(review): export ids are multiplied by 10000, presumably to keep them
    %distinct from import ids; that separation only holds while S*N < 10000 --
    %confirm.
    cluster_vecM = cluster_vec_aux(ind_Mshifts_vec);
    cluster_vecX = cluster_vec_aux(ind_Xshifts_vec)*10000;
    cluster_vec =  [cluster_vecM; cluster_vecX];

    clear aux cty_list sec_list cty_vec sec_vec cty_pos sec_pos matched
end

    %Free intermediate variables that are no longer needed
    clearvars shareIV_M shareIV_X shareIV_T m_ig_0_sort x_ig_0_sort m_ig_0_cumsum x_ig_0_cumsum m_ig_0_vec x_ig_0_vec ww_M ww_X ww_T ...
              ind_Msample_mat ind_Msample_mat1 ind_Msample_mat2 ind_Msample_mat_aux ind_Tsample_mat ind_Tsample_mat1 ind_Tsample_mat2 ind_Tsample_mat_aux ind_Xsample_mat ind_Xsample_mat1 ind_Xsample_mat2

    clearvars d_rm_data dln_pm_data dln_px_data dln_qm_data dln_qx_data dln_vm_data dln_vx_data taustar_ig_0_vec tau_star_ig_1 taustar_ig_1_vec tau_ig_1_vec tau_ig_1 tau_ig_0_vec


%% Compute matrix of shares in the researcher's model given initial conditions
%Calls compute_equilibrium_foa_full (defined outside this file) at the
%inverted initial parameters and stacks the returned responses of import
%prices, tariff revenue and export prices to home and foreign tariffs into
%shareMOD_dW.
disp('Compute model jacobian and IVs')

%compute shares
%Outputs named rho_<outcome>_ig_<shock>_ig are grouped by outcome (qm, pm, px,
%pstar, rm) and, within each outcome, by shock (tau, taustar, a, zstar, astar,
%delta). Only the tau and taustar outputs of pm, px and rm are used below.
%The call uses the _g0 adjustment parameters (adj_inner_g0, adj_outter_g0).
%NOTE(review): arguments are positional; the order gammaQ, gammaX, gammaM,
%rhoX, rhoM must match the function's signature -- confirm against its definition.
    [p_s_IV, PM_s_IV, ...
             rho_qm_ig_tau_ig, rho_qm_ig_taustar_ig, rho_qm_ig_a_ig, rho_qm_ig_zstar_ig, rho_qm_ig_astar_ig, rho_qm_ig_delta_ig, ...
             rho_pm_ig_tau_ig, rho_pm_ig_taustar_ig, rho_pm_ig_a_ig, rho_pm_ig_zstar_ig, rho_pm_ig_astar_ig, rho_pm_ig_delta_ig, ...
             rho_px_ig_tau_ig, rho_px_ig_taustar_ig, rho_px_ig_a_ig, rho_px_ig_zstar_ig, rho_px_ig_astar_ig, rho_px_ig_delta_ig, ...
             rho_pstar_ig_tau_ig, rho_pstar_ig_taustar_ig, rho_pstar_ig_a_ig, rho_pstar_ig_zstar_ig, rho_pstar_ig_astar_ig, rho_pstar_ig_delta_ig, ...
             rho_rm_ig_tau_ig, rho_rm_ig_taustar_ig, rho_rm_ig_a_ig, rho_rm_ig_zstar_ig, rho_rm_ig_astar_ig, rho_rm_ig_delta_ig] ...
             = compute_equilibrium_foa_full( delta_ig_0, a_star_ig_0, z_star_ig_0, a_ig_0, aM_g_0, AM_s_0, AD_s_0, betaT_0, beta_s_0, alphaL_s_0, alphaI_s_0, alpha_ks_0, barL_rs_0, Z_rs_0, D_0, ... 
             tau_ig_0, tau_star_ig_0, m_ig_0, E_s_0, p_s_0, Dsg, DX, DM, omega_star, sigma, eta, kappa, sigma_star, nu, epsilon, gammaQ, gammaX, gammaM, rhoX, rhoM, tol_conv, adj_inner_g0, adj_outter_g0, ... 
             ind_Mshifts_vec, ind_Xshifts_vec, ind_Msample_ig, ind_Xsample_ig );  


    %Columns: responses to home tariffs followed by responses to foreign tariffs.
    %Rows are stacked as [import price; tariff revenue; export price], the same
    %block order as dWout and indM / indT / indX.
    share_pm    = [rho_pm_ig_tau_ig    , rho_pm_ig_taustar_ig   ];
    share_px    = [rho_px_ig_tau_ig    , rho_px_ig_taustar_ig   ];
    share_rm    = [rho_rm_ig_tau_ig    , rho_rm_ig_taustar_ig   ];
    shareMOD_dW    = sparse([share_pm; share_rm; share_px]);

%% IV for welfare
%Sets up dimensions, a rescaling constant and the control matrix used for the
%welfare IV.

    %Nobs: rows (stacked outcomes) and Nsh: columns (shifters) of the model
    %share matrix; Nobs_n: number of observations in the naive sample
    [Nobs, Nsh] = size(shareMOD_dW);
    Nobs_n = sum(sample_naive);

    %Number of observations times the ratio of the mean to the standard
    %deviation of the raw shifters (the two moments used to build shiftersIV)
    adjGE_const = Nobs*avg_tariff/std_tariff;

%Control vectors   
%Controls: constant, row sum of the model shares, export indicator and import
%indicator. indT is left out -- presumably because indM + indT + indX equals
%the constant; confirm.
    shareMOD_dWbar = sum(shareMOD_dW,2);
    C = [ones(Nobs,1), shareMOD_dWbar, indX, indM];

%Compute matrices for OLS coefs    
%MC_term1 = (C'C)^(-1) C', i.e. the matrix that maps a vector to its OLS
%coefficients on the controls C
    MC_term1 = (C'*C)\C';    